{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Expose only the GPU with index 2 to CUDA; set ahead of the TensorFlow import in the next cell.\n",
    "os.environ.update({'CUDA_VISIBLE_DEVICES': '2'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the whole of pair.json and decode it into `data`.\n",
    "with open('pair.json') as fopen:\n",
    "    data = json.loads(fopen.read())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model:\n",
    "    \"\"\"Stacked bidirectional LSTM classifier over token-id sequences (TF1 graph mode).\n",
    "\n",
    "    Constructing an instance adds the following to the default graph:\n",
    "      X          int32 placeholder [batch, time] of token ids; trailing zeros are\n",
    "                 treated as padding.\n",
    "      Y          int32 placeholder [batch] of class ids (0 or 1, matching the 2 logits).\n",
    "      lengths    int32 [batch], number of tokens up to and including the last\n",
    "                 non-zero id in each row of X.\n",
    "      out        [batch, 2 * (size_layer // 2)] summary vector of each sequence.\n",
    "      logits     [batch, 2] unnormalised class scores.\n",
    "      cost       mean sparse softmax cross-entropy over the batch.\n",
    "      optimizer  Adam op that minimises `cost`.\n",
    "      accuracy   fraction of rows whose argmax(logits) equals Y.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, size_layer, num_layers, embedded_size,\n",
    "                 dict_size, learning_rate, dropout):\n",
    "        \"\"\"Build the graph.\n",
    "\n",
    "        size_layer: combined width of one bidirectional layer (half per direction).\n",
    "        num_layers: number of stacked bidirectional layers (at least 1).\n",
    "        embedded_size: width of each token embedding vector.\n",
    "        dict_size: number of rows in the embedding table.\n",
    "        learning_rate: learning rate handed to Adam.\n",
    "        dropout: keep probability applied to LSTM outputs (1.0 disables dropout).\n",
    "        \"\"\"\n",
    "\n",
    "        def cells(size, reuse=False):\n",
    "            cell = tf.nn.rnn_cell.LSTMCell(\n",
    "                size, initializer=tf.orthogonal_initializer(), reuse=reuse\n",
    "            )\n",
    "            return tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=dropout)\n",
    "\n",
    "        def birnn(inputs, lengths, scope):\n",
    "            with tf.variable_scope(scope, reuse = tf.AUTO_REUSE):\n",
    "                for n in range(num_layers):\n",
    "                    # sequence_length stops each direction at the real end of every row,\n",
    "                    # so padded steps neither update the state nor produce outputs.\n",
    "                    (out_fw, out_bw), _ = tf.nn.bidirectional_dynamic_rnn(\n",
    "                        cell_fw = cells(size_layer // 2),\n",
    "                        cell_bw = cells(size_layer // 2),\n",
    "                        inputs = inputs,\n",
    "                        sequence_length = lengths,\n",
    "                        dtype = tf.float32,\n",
    "                        scope = 'bidirectional_rnn_%d'%(n))\n",
    "                    inputs = tf.concat((out_fw, out_bw), 2)\n",
    "                # Previously this returned inputs[:, -1]: with post-padded batches that is\n",
    "                # the forward output after running over padding and the backward output\n",
    "                # after a single step. Read each direction where it has seen the whole\n",
    "                # sequence instead: forward at the last real token, backward at position 0.\n",
    "                rows = tf.range(tf.shape(out_fw)[0])\n",
    "                last_step = tf.maximum(lengths - 1, 0)\n",
    "                final_fw = tf.gather_nd(out_fw, tf.stack([rows, last_step], axis = 1))\n",
    "                final_bw = out_bw[:, 0]\n",
    "                return tf.concat((final_fw, final_bw), 1)\n",
    "\n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.int32, [None])\n",
    "        self.batch_size = tf.shape(self.X)[0]\n",
    "        # Length = 1-based position of the last non-zero id, so only trailing zeros\n",
    "        # count as padding (assumes id 0 is the pad value, as pad_sequences uses by default).\n",
    "        positions = tf.range(1, tf.shape(self.X)[1] + 1)\n",
    "        is_token = tf.cast(tf.not_equal(self.X, 0), tf.int32)\n",
    "        self.lengths = tf.reduce_max(is_token * positions, axis = 1)\n",
    "\n",
    "        encoder_embeddings = tf.Variable(tf.random_uniform([dict_size, embedded_size], -1, 1))\n",
    "        embedded_left = tf.nn.embedding_lookup(encoder_embeddings, self.X)\n",
    "\n",
    "        self.out = birnn(embedded_left, self.lengths, 'left')\n",
    "        self.logits = tf.layers.dense(self.out, 2)\n",
    "        self.cost = tf.reduce_mean(\n",
    "            tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
    "                logits = self.logits, labels = self.Y\n",
    "            )\n",
    "        )\n",
    "\n",
    "        self.optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate).minimize(self.cost)\n",
    "        correct_pred = tf.equal(\n",
    "            tf.argmax(self.logits, 1, output_type = tf.int32), self.Y\n",
    "        )\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model architecture\n",
    "size_layer = embedded_size = 256\n",
    "num_layers = 2\n",
    "vocab_size = 30000\n",
    "\n",
    "# Optimisation settings\n",
    "learning_rate, batch_size = 1e-3, 128\n",
    "# Passed to Model as the LSTM output keep probability, so 1.0 means no dropout.\n",
    "dropout = 1.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-4-c2607cc1ae8a>:6: LSTMCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This class is equivalent as tf.keras.layers.LSTMCell, and will be replaced by that in Tensorflow 2.0.\n",
      "WARNING:tensorflow:\n",
      "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
      "For more information, please see:\n",
      "  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
      "  * https://github.com/tensorflow/addons\n",
      "  * https://github.com/tensorflow/io (for I/O related ops)\n",
      "If you depend on functionality not listed there, please file an issue.\n",
      "\n",
      "WARNING:tensorflow:From <ipython-input-4-c2607cc1ae8a>:17: bidirectional_dynamic_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `keras.layers.Bidirectional(keras.layers.RNN(cell))`, which is equivalent to this API\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/ops/rnn.py:464: dynamic_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `keras.layers.RNN(cell)`, which is equivalent to this API\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/ops/rnn_cell_impl.py:958: Layer.add_variable (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.add_weight` method instead.\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/ops/rnn_cell_impl.py:962: calling Zeros.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n",
      "WARNING:tensorflow:From <ipython-input-4-c2607cc1ae8a>:28: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.Dense instead.\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/layers/core.py:187: Layer.apply (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.__call__` method instead.\n"
     ]
    }
   ],
   "source": [
    "# Clear the default graph, open a session on it, build the model and initialise its variables.\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Model(\n",
    "    size_layer=size_layer,\n",
    "    num_layers=num_layers,\n",
    "    embedded_size=embedded_size,\n",
    "    dict_size=vocab_size,\n",
    "    learning_rate=learning_rate,\n",
    "    dropout=dropout,\n",
    ")\n",
    "init_op = tf.global_variables_initializer()\n",
    "sess.run(init_op)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['left_train', 'label_train', 'left_test', 'label_test'])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the top-level keys of the loaded JSON (the available splits).\n",
    "data.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull the training and test inputs/labels out of the loaded dict.\n",
    "train_X_left, train_Y = data['left_train'], data['label_train']\n",
    "test_X_left, test_Y = data['left_test'], data['label_test']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Short alias for the Keras padding helper used when assembling minibatches.\n",
    "keras_sequence = tf.keras.preprocessing.sequence\n",
    "pad_sequences = keras_sequence.pad_sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 2046/2046 [10:43<00:00,  3.18it/s, accuracy=0.714, cost=0.563]\n",
      "test minibatch loop: 100%|██████████| 105/105 [00:12<00:00,  8.71it/s, accuracy=0.747, cost=0.492]\n",
      "train minibatch loop:   0%|          | 0/2046 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, pass acc: 0.000000, current acc: 0.713587\n",
      "time taken: 655.4127008914948\n",
      "epoch: 0, training loss: 0.647283, training acc: 0.582361, valid loss: 0.543352, valid acc: 0.713587\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 2046/2046 [10:36<00:00,  3.22it/s, accuracy=0.833, cost=0.418]\n",
      "test minibatch loop: 100%|██████████| 105/105 [00:12<00:00,  8.59it/s, accuracy=0.759, cost=0.443]\n",
      "train minibatch loop:   0%|          | 0/2046 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, pass acc: 0.713587, current acc: 0.742646\n",
      "time taken: 648.4790985584259\n",
      "epoch: 0, training loss: 0.507708, training acc: 0.743618, valid loss: 0.507135, valid acc: 0.742646\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 2046/2046 [12:53<00:00,  2.65it/s, accuracy=0.952, cost=0.243]\n",
      "test minibatch loop: 100%|██████████| 105/105 [00:19<00:00,  5.46it/s, accuracy=0.771, cost=0.435]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 792.6765551567078\n",
      "epoch: 0, training loss: 0.440407, training acc: 0.789272, valid loss: 0.525370, valid acc: 0.738147\n",
      "\n",
      "break epoch:0\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "\n",
    "# EARLY_STOPPING: consecutive epochs without a validation-accuracy improvement that are\n",
    "# tolerated before training stops. CURRENT_CHECKPOINT counts such epochs, CURRENT_ACC is\n",
    "# the best validation accuracy so far and EPOCH is the index of the epoch being run.\n",
    "EARLY_STOPPING, CURRENT_CHECKPOINT, CURRENT_ACC, EPOCH = 1, 0, 0, 0\n",
    "\n",
    "\n",
    "def _run_epoch(inputs, labels, desc, train=False):\n",
    "    \"\"\"Run one pass over `inputs`/`labels` in minibatches of `batch_size`.\n",
    "\n",
    "    Every batch is post-padded to its own longest sequence before being fed to\n",
    "    the model. With `train=True` the optimizer op is run as well and a NaN loss\n",
    "    aborts the run.\n",
    "\n",
    "    Returns (mean_loss, mean_accuracy) averaged per example, so a shorter final\n",
    "    batch counts in proportion to its size. Both are NaN when `inputs` is empty.\n",
    "    \"\"\"\n",
    "    fetches = [model.accuracy, model.cost]\n",
    "    if train:\n",
    "        fetches.append(model.optimizer)\n",
    "\n",
    "    loss_sum, acc_sum, seen = 0.0, 0.0, 0\n",
    "    pbar = tqdm(range(0, len(inputs), batch_size), desc=desc)\n",
    "    for start in pbar:\n",
    "        batch_x = pad_sequences(inputs[start:start + batch_size], padding='post')\n",
    "        batch_y = labels[start:start + batch_size]\n",
    "        acc, loss = sess.run(\n",
    "            fetches, feed_dict={model.X: batch_x, model.Y: batch_y}\n",
    "        )[:2]\n",
    "        if train:\n",
    "            assert not np.isnan(loss), 'training loss became NaN'\n",
    "        loss_sum += float(loss) * len(batch_y)\n",
    "        acc_sum += float(acc) * len(batch_y)\n",
    "        seen += len(batch_y)\n",
    "        pbar.set_postfix(cost=loss, accuracy=acc)\n",
    "\n",
    "    if seen == 0:\n",
    "        return float('nan'), float('nan')\n",
    "    return loss_sum / seen, acc_sum / seen\n",
    "\n",
    "\n",
    "while CURRENT_CHECKPOINT < EARLY_STOPPING:\n",
    "    lasttime = time.time()\n",
    "\n",
    "    train_loss, train_acc = _run_epoch(\n",
    "        train_X_left, train_Y, 'train minibatch loop', train=True\n",
    "    )\n",
    "    test_loss, test_acc = _run_epoch(test_X_left, test_Y, 'test minibatch loop')\n",
    "\n",
    "    if test_acc > CURRENT_ACC:\n",
    "        print(\n",
    "            'epoch: %d, pass acc: %f, current acc: %f'\n",
    "            % (EPOCH, CURRENT_ACC, test_acc)\n",
    "        )\n",
    "        CURRENT_ACC = test_acc\n",
    "        CURRENT_CHECKPOINT = 0\n",
    "    else:\n",
    "        CURRENT_CHECKPOINT += 1\n",
    "\n",
    "    print('time taken:', time.time()-lasttime)\n",
    "    print('epoch: %d, training loss: %f, training acc: %f, valid loss: %f, valid acc: %f\\n'%(EPOCH,train_loss,\n",
    "                                                                                          train_acc,test_loss,\n",
    "                                                                                          test_acc))\n",
    "    # The counter was never advanced before, so every log line reported epoch 0.\n",
    "    EPOCH += 1\n",
    "\n",
    "print('break epoch:%d\\n' % (EPOCH))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
